from mynumpy import *
import lhsmdu

def iid(ndims,nsamps):
    """Draw `nsamps` independent points in `ndims` dimensions.

    Returns the array produced by ``rand(ndims, nsamps)``: one row per
    dimension, one column per sample.  (Assumes ``rand`` is numpy's
    ``random.rand`` re-exported by ``mynumpy``, i.e. uniform on [0, 1) --
    confirm against that module.)
    """
    samples = rand(ndims, nsamps)
    return samples

def getbins(ndims,nsamps):
    """Return the number of bins per axis for a one-sample-per-cell grid.

    The grid has ``bins**ndims`` cells holding one sample each, so `nsamps`
    must be an exact `ndims`-th power of an integer.

    Args:
        ndims: number of dimensions (>= 1).
        nsamps: total number of samples (>= 0).

    Returns:
        int: bins per axis, satisfying ``bins**ndims == nsamps``.

    Raises:
        ValueError: if ``ndims < 1``, ``nsamps < 0``, or `nsamps` is not a
            perfect `ndims`-th power.
    """
    if ndims < 1:
        raise ValueError(f"ndims must be >= 1, got {ndims}")
    if nsamps < 0:
        raise ValueError(f"nsamps must be >= 0, got {nsamps}")
    # The floating-point root can land just below the true integer
    # (125 ** (1/3) == 4.999999999999999), so round it first and then confirm
    # with exact integer arithmetic.  A tolerance test such as np.allclose is
    # too loose here: its relative tolerance accepts e.g. nsamps=1000001 in 2-D.
    bins = int(round(nsamps ** (1 / ndims)))
    if bins ** ndims != nsamps:
        raise ValueError(
            f"nsamps={nsamps} is not an integer raised to the power "
            f"ndims={ndims}; cannot place exactly one sample per bin"
        )
    return bins

def stratified(ndims,nsamps):
    """Stratified (jittered) sampling of the unit hypercube.

    Splits [0, 1)**ndims into ``bins**ndims`` equal cells, where ``bins``
    comes from ``getbins(ndims, nsamps)``.  It then places one random point
    in each cell: cell corner + ``rand(...) / bins``.

    For ``ndims == 2`` the column order matches the original 2-D
    implementation: meshgrid's default 'xy' layout, with the first coordinate
    varying fastest.  Other ``ndims`` values are also supported.

    Returns:
        array of shape (ndims, nsamps).
    """
    # getbins may return a float that sits just below an integer
    # (e.g. 4.999999999999999); snap it to an exact integer cell count.
    bins = int(round(getbins(ndims, nsamps)))
    # arange(bins) / bins always has exactly `bins` entries.  arange(0, 1, 1/bins)
    # uses a non-integer step, so its length depends on float rounding.
    edges = arange(bins) / bins
    corners = vstack([grid.ravel() for grid in meshgrid(*([edges] * ndims))])
    # jitter every corner independently within its own cell
    return corners + rand(ndims, nsamps) / bins

def qmc(ndims,nsamps):
    """Regular grid of cell corners moved by one shared random shift.

    Builds the same ``bins**ndims`` cell corners as `stratified`, with
    ``bins = getbins(ndims, nsamps)``.  Unlike `stratified`, it draws a single
    random vector of size ``ndims`` scaled by ``1 / bins`` and adds it to
    every corner, so all points keep the same relative lattice structure.
    Coordinates stay below 1 as long as ``rand`` draws from [0, 1).  (That
    ``rand`` is numpy's uniform ``random.rand`` is assumed -- confirm in
    mynumpy.)

    The original version always built a 2-D grid and a 2-component shift,
    regardless of `ndims`.  This version builds the grid in `ndims`
    dimensions; its 2-D output has the same layout as before.

    Returns:
        array of shape (ndims, nsamps).
    """
    # snap a possibly-inexact float from getbins to an integer bin count
    bins = int(round(getbins(ndims, nsamps)))
    # exactly `bins` corner coordinates per axis (no float-step arange)
    edges = arange(bins) / bins
    base = vstack([grid.ravel() for grid in meshgrid(*([edges] * ndims))])
    # one column vector shift, broadcast across every sample
    shift = rand(ndims, 1) / bins
    return base + shift

def anti(ndims,nsamps):
    """Antithetic sampling in 2-D using all four reflections of each point.

    Draws ``nsamps // 4`` base points ``(x, y)`` with ``rand``.  For each one
    it emits four consecutive columns:
    ``(x, y)``, ``(x, 1 - y)``, ``(1 - x, y)``, ``(1 - x, 1 - y)``.
    Only ``4 * (nsamps // 4)`` samples are returned, so up to three fewer
    than requested when `nsamps` is not a multiple of 4.

    Returns:
        array of shape (2, 4 * (nsamps // 4)).
    """
    assert(ndims==2)
    nbase = nsamps // 4
    base = rand(ndims, nbase)
    x, y = base
    reflections = [
        base,                 # (x, y)
        array([x, 1 - y]),    # reflect the second coordinate
        array([1 - x, y]),    # reflect the first coordinate
        1 - base,             # reflect both
    ]

    # interleave so each group of four columns holds one base point's reflections
    result = zeros((ndims, 4 * nbase))
    for offset, block in enumerate(reflections):
        result[:, offset::4] = block
    return result

def anti_both(ndims,nsamps):
    """Antithetic pairs in 2-D: each point is followed by its full reflection.

    Draws ``nsamps // 2`` points ``u`` with ``rand`` and returns them
    interleaved with ``1 - u``.  Even columns hold ``u`` and odd columns hold
    ``1 - u``.  The result has ``2 * (nsamps // 2)`` columns, so it is one
    short of `nsamps` when `nsamps` is odd.

    Returns:
        array of shape (2, 2 * (nsamps // 2)).
    """
    assert ndims == 2
    half = nsamps // 2
    primary = rand(ndims, half)

    out = zeros((ndims, 2 * half))
    out[:, ::2] = primary
    out[:, 1::2] = 1 - primary
    return out

def anti_1st(ndims,nsamps):
    """Antithetic pairs in 2-D that shift only the first coordinate.

    Draws a full ``(ndims, 2 * (nsamps // 2))`` array with ``rand``.  Each odd
    column is then overwritten from the even column before it:
    its first coordinate becomes ``(x + 0.5) mod 1`` and its second
    coordinate is copied unchanged.  The random draws originally placed in
    the odd columns are discarded, but they are still consumed from the
    generator.

    Returns:
        array of shape (2, 2 * (nsamps // 2)).
    """
    assert ndims == 2
    npairs = nsamps // 2
    samples = rand(ndims, 2 * npairs)

    # even columns are the originals; odd columns are derived from them
    originals = samples[:, 0::2]
    samples[1, 1::2] = originals[1]
    samples[0, 1::2] = np.mod(originals[0] + .5, 1)
    return samples


def latin(ndims,nsamps):
    """Latin-hypercube samples from ``lhsmdu``, ordered by the second coordinate.

    Converts ``lhsmdu.sample(ndims, nsamps)`` to an array with one row per
    dimension.  It then reorders the columns (samples) by ascending value of
    row 1.  Because row 1 is the sort key, this requires ``ndims >= 2``.
    """
    design = array(lhsmdu.sample(ndims, nsamps))
    order = design[1, :].argsort()
    # index through the transpose so the result has the same layout as before
    return design.T[order].T